import csv
import json
import re
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from urllib.error import URLError
from urllib.error import HTTPError
from urllib.parse import urlencode
from urllib.request import Request, urlopen
import urllib.request

import pandas as pd


# Region-code filter: two uppercase letters followed by one to three uppercase
# letters or digits, i.e. NUTS1/2/3-like and similar codes (e.g., UKC, EL30, BG311).
# Bare two-letter country codes and purely numeric codes do not match.
GEO_RE = re.compile(r"^[A-Z]{2}[A-Z0-9]{1,3}$")
# Two-letter code prefixes removed from the query list: not covered by this Eurostat dataset.
EXCLUDE_PREFIXES = {"IL", "RU", "UA"}


@dataclass(frozen=True)
class EurostatQuery:
    """Immutable description of one Eurostat request: dataset, unit and year."""

    # Eurostat dataset code; appended to the API base URL as the final path segment.
    dataset: str
    # Value sent as the ``unit`` query parameter.
    unit: str
    # Reference year; sent (as a string) as the ``time`` query parameter.
    year: int


def _unravel(flat: int, sizes: list[int]) -> list[int]:
    idx = []
    for size in reversed(sizes):
        idx.append(flat % size)
        flat //= size
    return list(reversed(idx))


def _codes_by_pos(dim: dict) -> list[str]:
    index = (dim.get("category", {}) or {}).get("index", {}) or {}
    # index is mapping code -> position
    codes = [None] * len(index)
    for code, pos in index.items():
        if isinstance(pos, int) and 0 <= pos < len(codes):
            codes[pos] = code
    return [c for c in codes if c is not None]


def fetch_batch(query: EurostatQuery, geos: list[str], opener: urllib.request.OpenerDirector) -> dict[str, Any]:
    """Fetch and decode one Eurostat response for a batch of geo codes.

    The request is attempted up to six times with exponential backoff
    (1s, 2s, 4s, ... capped at 15s) for transient network failures.

    Args:
        query: Dataset, unit and year to request.
        geos: Geo codes to request; sent as a single ``;``-joined ``geo`` value.
        opener: Opener used to perform the HTTP request.

    Returns:
        The response body parsed from UTF-8 JSON.

    Raises:
        HTTPError: Immediately for statuses in the fail-fast set, or after the
            last attempt for any other HTTP status.
        OSError: After the last attempt for network-level failures (includes
            ``URLError`` and timeouts).
        ValueError: If the body is not valid UTF-8 JSON.
    """
    base = "https://ec.europa.eu/eurostat/api/dissemination/statistics/1.0/data/"
    # NOTE(review): confirm that the endpoint accepts one ';'-joined geo value
    # rather than requiring repeated geo= parameters.
    params = {"unit": query.unit, "time": str(query.year), "geo": ";".join(geos)}
    url = base + query.dataset + "?" + urlencode(params)
    req = Request(url, headers={"User-Agent": "ess-research/1.0"})

    # Treat server errors / oversized requests as non-retryable here so the
    # caller can split the batch instead of waiting through the backoff.
    fail_fast_statuses = {400, 413, 414, 431, 500, 502, 503}
    max_attempts = 6
    backoff = 1.0
    for attempt in range(1, max_attempts + 1):
        try:
            with opener.open(req, timeout=60) as resp:
                raw = resp.read()
        except OSError as exc:
            # OSError covers HTTPError, URLError and timeouts, plus socket
            # errors raised while reading the body (e.g. connection resets),
            # which are just as transient and deserve the same retry.
            if isinstance(exc, HTTPError) and exc.code in fail_fast_statuses:
                raise
            if attempt == max_attempts:
                raise
        else:
            # Decoding errors are not transient; let them propagate unretried.
            return json.loads(raw.decode("utf-8"))
        time.sleep(backoff)
        backoff = min(backoff * 2, 15)
    raise RuntimeError("Unexpected retry fallthrough")


def fetch_geos_strict(
    query: EurostatQuery, geos: list[str], opener: urllib.request.OpenerDirector
) -> tuple[list[str], list[dict[str, Any]]]:
    """
    Fetch values for ``geos``, isolating invalid geo codes by bisection.

    Eurostat endpoint behaves strictly: if any requested geo is invalid, response may be empty.
    Whenever a batch fails or comes back without any geo category, the batch is
    split in half and each half is fetched recursively; a single geo that still
    yields nothing is dropped.

    Args:
        query: Dataset, unit and year to request.
        geos: Geo codes to request.
        opener: Opener used to perform the HTTP requests.

    Returns:
        ``(valid_geos, records)``: ``valid_geos`` are the geo codes listed in
        the response's geo index; ``records`` has one dict per returned value
        with keys ``geo``, ``year``, ``unit`` and ``value``. A valid geo may
        have no record when the response carries no value for it.
    """
    if not geos:
        return [], []

    try:
        data = fetch_batch(query, geos, opener)
    except Exception:
        # Deliberately broad: any failure (HTTP error, network error, undecodable
        # body) is handled exactly like an empty response, i.e. by splitting.
        data = {}

    geo_dim = (data.get("dimension") or {}).get("geo") or {}
    geo_index = (geo_dim.get("category") or {}).get("index") or {}
    if not geo_index:
        if len(geos) == 1:
            return [], []
        mid = len(geos) // 2
        left_ok, left_rows = fetch_geos_strict(query, geos[:mid], opener)
        right_ok, right_rows = fetch_geos_strict(query, geos[mid:], opener)
        return left_ok + right_ok, left_rows + right_rows

    # Parse rows: each key of "value" is a flat offset into the dimension grid
    # described by "id" (dimension order) and "size" (dimension lengths).
    ids = data.get("id") or []
    sizes = data.get("size") or []
    dims = data.get("dimension") or {}
    codes = {dim_id: _codes_by_pos(dims.get(dim_id) or {}) for dim_id in ids}

    rows = []
    for flat_key, value in (data.get("value") or {}).items():
        coords = _unravel(int(flat_key), sizes)
        coord_map = {dim_id: codes[dim_id][coords[i]] for i, dim_id in enumerate(ids)}
        rows.append(
            {
                "geo": coord_map.get("geo", ""),
                # Fall back to the requested year/unit when the response omits the dimension.
                "year": int(coord_map.get("time", query.year)),
                "unit": coord_map.get("unit", query.unit),
                "value": value,
            }
        )

    return list(geo_index), rows


def chunk_list(xs: list[str], n: int) -> list[list[str]]:
    """Split ``xs`` into consecutive chunks of at most ``n`` items.

    Args:
        xs: Items to split; order is preserved.
        n: Maximum chunk length; must be at least 1.

    Returns:
        The chunks in order; only the last one may be shorter than ``n``.
        An empty ``xs`` yields an empty list.

    Raises:
        ValueError: If ``n`` is smaller than 1. A negative step would otherwise
            silently produce no chunks at all.
    """
    if n < 1:
        raise ValueError(f"chunk size must be a positive integer, got {n!r}")
    return [xs[start : start + n] for start in range(0, len(xs), n)]


def main() -> None:
    """Fetch Eurostat ``isoc_r_broad_h`` values per region and append them to a CSV.

    Steps:
      1. Parse the command line.
      2. Load the ESS parquet file and build the list of geo codes to query
         (from ``--regions-from-csv``, all ESS regions, or the ESS regions of
         the requested survey year).
      3. Skip regions already present in the output CSV (resume support).
      4. Fetch the remaining regions in batches and append the rows to the CSV
         after every batch, so an interrupted run keeps its progress.

    Raises:
        SystemExit: On invalid ``--regunit`` or ``--chunk`` values.
    """
    import argparse
    import sys

    root = Path(__file__).resolve().parents[1]
    out_dir = root / "data_external"

    ap = argparse.ArgumentParser()
    ap.add_argument("--year", type=int, required=True)
    ap.add_argument("--unit", required=True, choices=["PC_HH", "PC_HH_IACC"])
    ap.add_argument("--chunk", type=int, default=30)
    ap.add_argument("--max-geos", type=int, default=0)
    ap.add_argument("--use-system-proxy", action="store_true", help="Use Windows system proxy (default: no proxy).")
    ap.add_argument("--regunit", default="1,2", help="Comma-separated allowed regunit levels (default: 1,2)")
    ap.add_argument(
        "--use-all-ess-regions",
        action="store_true",
        help="Query all ESS regions (after regunit filter), ignoring --year filter; useful for fetching baseline years.",
    )
    ap.add_argument(
        "--regions-from-csv",
        default="",
        help="Optional CSV path containing a 'region' column to define the geo list (e.g., existing Eurostat region-year file).",
    )
    args = ap.parse_args()
    if args.chunk < 1:
        # A zero chunk size crashes the batching and a negative one silently fetches nothing.
        ap.error("--chunk must be a positive integer")

    # Created only after argument parsing so that --help / usage errors leave no directory behind.
    out_dir.mkdir(parents=True, exist_ok=True)

    ess = pd.read_parquet(root / "outputs" / "ess_r8_r11_min.parquet")
    allowed_regunit = sorted({int(x.strip()) for x in str(args.regunit).split(",") if x.strip().isdigit()})
    if not allowed_regunit:
        raise SystemExit("Invalid --regunit; expected comma-separated ints like 1,2")
    # Drop null regions BEFORE the string cast: otherwise NaN/None become the
    # literal codes "NAN"/"NONE", which both match GEO_RE and would be queried.
    ess = ess[ess["regunit"].isin(allowed_regunit) & ess["region"].notna()].copy()
    ess["region"] = ess["region"].astype(str).str.upper().str.strip()

    # Mapping region -> (cntry, regunit) from ESS (for later merge)
    region_map = (
        ess[["cntry", "regunit", "region"]]
        .dropna()
        .drop_duplicates()
        .groupby("region", as_index=False)
        .agg({"cntry": "first", "regunit": "first"})
    )
    region_to_cntry = dict(zip(region_map["region"], region_map["cntry"]))
    region_to_regunit = dict(zip(region_map["region"], region_map["regunit"]))

    def keep_geo_codes(candidates) -> list[str]:
        # GEO_RE requires a two-letter prefix, so numeric placeholders such as
        # "99999" are rejected without a separate check.
        return sorted({r for r in candidates if GEO_RE.match(r)})

    # Clean list of geo codes to query
    # Default: only regions appearing in the target survey year (keeps fetches small).
    regions_source = None
    if args.regions_from_csv:
        try:
            src = pd.read_csv(Path(args.regions_from_csv))
            regions_source = src["region"].dropna().astype(str).str.upper().str.strip().tolist()
        except Exception as exc:
            # Keep the best-effort fallback, but never switch region lists silently.
            print(
                f"warning: could not read regions from {args.regions_from_csv!r} ({exc}); "
                "falling back to ESS regions",
                file=sys.stderr,
            )
            regions_source = None

    if regions_source is not None:
        regions_all = keep_geo_codes(regions_source)
    elif args.use_all_ess_regions:
        regions_all = keep_geo_codes(ess["region"].unique().tolist())
    else:
        ess_y = ess[ess["survey_year"] == args.year]
        regions_all = keep_geo_codes(ess_y["region"].unique().tolist())
    regions = [r for r in regions_all if r[:2] not in EXCLUDE_PREFIXES]
    if args.max_geos and args.max_geos > 0:
        regions = regions[: args.max_geos]

    dataset = "isoc_r_broad_h"

    all_valid = set()

    query = EurostatQuery(dataset=dataset, unit=args.unit, year=args.year)
    # build_opener() keeps the default proxy handling; an empty ProxyHandler disables proxies.
    if args.use_system_proxy:
        opener = urllib.request.build_opener()
    else:
        opener = urllib.request.build_opener(urllib.request.ProxyHandler({}))

    out_long = out_dir / f"broadband_region_year_eurostat_isoc_r_broad_h_long_{args.year}_{args.unit}.csv"

    # Resume support: regions that already have rows in the output are not fetched again.
    already = set()
    if out_long.exists() and out_long.stat().st_size > 0:
        try:
            prev = pd.read_csv(out_long, usecols=["region"])
            already = set(prev["region"].astype(str).str.upper().tolist())
        except Exception as exc:
            print(
                f"warning: could not read existing output {out_long} ({exc}); "
                "all regions will be fetched again and appended",
                file=sys.stderr,
            )
            already = set()

    regions_todo = [r for r in regions if r not in already]
    batches = chunk_list(regions_todo, args.chunk)
    fieldnames = ["cntry", "regunit", "region", "year", "unit", "value", "source", "notes"]

    for i, batch in enumerate(batches, start=1):
        valid_geos, rows = fetch_geos_strict(query, batch, opener)
        all_valid.update(valid_geos)
        out_rows = [
            {
                "cntry": region_to_cntry.get(r["geo"], ""),
                "regunit": region_to_regunit.get(r["geo"], ""),
                "region": r["geo"],
                "year": r["year"],
                "unit": r["unit"],
                "value": r["value"],
                "source": "Eurostat API (isoc_r_broad_h)",
                "notes": "",
            }
            for r in rows
        ]

        if out_rows:
            # An existing but empty file still needs its header row.
            write_header = not out_long.exists() or out_long.stat().st_size == 0
            with out_long.open("a", newline="", encoding="utf-8") as f:
                w = csv.DictWriter(f, fieldnames=fieldnames)
                if write_header:
                    w.writeheader()
                w.writerows(out_rows)

        # Short pause between batches to go easy on the API.
        time.sleep(0.15)
        if i % 3 == 0 or i == len(batches):
            done = min(i * args.chunk, len(regions_todo))
            print(f"progress {done}/{len(regions_todo)} geos, valid_geos={len(all_valid)}", flush=True)

    print(f"Wrote/Updated: {out_long}")
    print(f"Valid geo codes found: {len(all_valid)} / requested {len(regions)} (todo={len(regions_todo)})")


# Run the fetch only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
